_base_ = [
    './swin-tiny-upernet_ade-pretrain.py'
]
checkpoint_file = './checkpoints/swin_base_patch4_window7_224_20220317-e9b98025.pth'  # noqa

training_steps = 80_000  # total training iterations (iteration-based loop)

# Swap the inherited backbone for Swin-B and resize both heads to match its
# wider stage outputs; both heads predict 150 classes.
model = dict(
    backbone=dict(
        init_cfg=dict(
            type='Pretrained',
            checkpoint=checkpoint_file,
        ),
        embed_dims=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
    ),
    # One entry per backbone stage: embed_dims (128) doubled at each stage.
    decode_head=dict(
        in_channels=[128, 256, 512, 1024],
        num_classes=150,
    ),
    # 512 matches the third entry of decode_head.in_channels; which stage the
    # auxiliary head actually reads is set in the base config (not visible).
    auxiliary_head=dict(
        in_channels=512,
        num_classes=150,
    ),
)

# Length of the linear warmup in iterations. Both schedulers read this value
# so the warmup hands off to cosine decay at exactly the same iteration.
warmup_steps = 8000

# Learning-rate schedule (iteration-based):
#   [0, warmup_steps)              linear ramp from 1e-6 x base LR to base LR
#   [warmup_steps, training_steps) cosine decay down to eta_min
# The base LR itself comes from the optimizer settings in the `_base_` config.
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-6,
        by_epoch=False,
        begin=0,
        end=warmup_steps,
    ),
    dict(
        type='CosineAnnealingLR',
        begin=warmup_steps,
        end=training_steps,
        eta_min=1e-7,
        by_epoch=False,
    ),
]

# Dataloader overrides. mmengine's `batch_size` is per GPU (per process), so
# the effective training batch is 16 x the number of GPUs.
# NOTE(review): the previous comment here said "8 GPUs with 2 images per
# GPU", which contradicts batch_size=16 — confirm the intended value.
# These are partial dicts: every field not set here is inherited from the
# `_base_` config by mmengine's config merging.
train_dataloader = dict(batch_size=16, num_workers=4)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader

# Iteration-based training for `training_steps` iterations, running
# validation every 4000 iterations.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=training_steps,
    val_interval=4000,
)